package org.broadinstitute.hellbender.tools.dragstr;

import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.EnumSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.Spliterator;
import java.util.Vector;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.BiConsumer;
import java.util.function.BinaryOperator;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

import org.apache.commons.io.output.NullOutputStream;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.ArgumentCollection;
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.barclay.help.DocumentedFeature;
import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions;
import org.broadinstitute.hellbender.cmdline.programgroups.ShortVariantDiscoveryProgramGroup;
import org.broadinstitute.hellbender.engine.GATKPath;
import org.broadinstitute.hellbender.engine.GATKTool;
import org.broadinstitute.hellbender.engine.ReadsDataSource;
import org.broadinstitute.hellbender.engine.ReadsPathDataSource;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.HaplotypeCaller;
import org.broadinstitute.hellbender.transformers.DRAGENMappingQualityReadTransformer;
import org.broadinstitute.hellbender.transformers.ReadTransformer;
import org.broadinstitute.hellbender.utils.BinaryTableReader;
import org.broadinstitute.hellbender.utils.IntervalMergingRule;
import org.broadinstitute.hellbender.utils.IntervalUtils;
import org.broadinstitute.hellbender.utils.SequenceDictionaryUtils;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.collections.AutoCloseableCollection;
import org.broadinstitute.hellbender.utils.dragstr.DragstrParamUtils;
import org.broadinstitute.hellbender.utils.dragstr.DragstrParams;
import org.broadinstitute.hellbender.utils.dragstr.STRTableFile;
import org.broadinstitute.hellbender.utils.gcs.BucketUtils;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.reference.AbsoluteCoordinates;

import htsjdk.samtools.CigarElement;
import htsjdk.samtools.CigarOperator;
import htsjdk.samtools.SAMFlag;
import htsjdk.samtools.SAMReadGroupRecord;
import htsjdk.samtools.SAMRecord;
import htsjdk.samtools.SAMSequenceDictionary;
import htsjdk.samtools.SamReaderFactory;
import htsjdk.samtools.util.IntervalTree;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectMaps;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.ints.IntList;
import scala.Tuple4;

/**
 * Estimates the parameters for the DRAGstr model for an input sample.
 * <p>
 *     This tool takes in the sampling sites generated by {@link ComposeSTRTableFile} on the same reference
 *     as the input sample.
 * </p>
 * <p>
 *     The end result is a text file containing three parameter tables (GOP, GCP, API) that can be fed
 *     directly to {@link HaplotypeCaller} --dragstr-params-path.
 * </p>
 */
@CommandLineProgramProperties(
        summary = "estimates the parameters for the DRAGstr model for the input sample using the output of the ComposeSTRTable tool",
        oneLineSummary = "estimates the parameters for the DRAGstr model",
        programGroup = ShortVariantDiscoveryProgramGroup.class
)
@DocumentedFeature
public class CalibrateDragstrModel extends GATKTool {

    // Short and full names of the command-line arguments declared by this tool.
    public static final String STR_TABLE_PATH_SHORT_NAME = "str";
    public static final String STR_TABLE_PATH_FULL_NAME = "str-table-path";
    public static final String PARALLEL_FULL_NAME = "parallel";
    public static final String THREADS_FULL_NAME = "threads";
    public static final String SHARD_SIZE_FULL_NAME = "shard-size";
    public static final String DOWN_SAMPLE_SIZE_FULL_NAME = "down-sample-size";
    public static final String DEBUG_SITES_OUTPUT_FULL_NAME = "debug-sites-output";
    public static final String FORCE_ESTIMATION_FULL_NAME = "force-estimation";

    // Default value of the shard-size argument (in base pairs).
    public static final int DEFAULT_SHARD_SIZE = 1_000_000;
    // Default value of the down-sample-size argument (cases per period and repeat-length combination).
    public static final int DEFAULT_DOWN_SAMPLE_SIZE = 4096;
    // Value of the threads argument that leaves the thread count up to the VM; also its default and minimum.
    public static final int SYSTEM_SUGGESTED_THREAD_NUMBER = 0;
    // Smallest accepted value for the shard-size argument.
    public static final int MINIMUM_SHARD_SIZE = 100;
    // Smallest accepted value for the down-sample-size argument.
    public static final int MINIMUM_DOWN_SAMPLE_SIZE = 512;

    // Hyper-parameters of the estimation (max period, max repeat-length, min depth, min MQ, ...);
    // validated on start-up.
    @ArgumentCollection
    private DragstrHyperParameters hyperParameters = new DragstrHyperParameters();

    // Mandatory: location of the STR table file with the sampling sites.
    @Argument(shortName=STR_TABLE_PATH_SHORT_NAME, fullName=STR_TABLE_PATH_FULL_NAME, doc="location of the zip that contains the sampling sites for the reference")
    private GATKPath strTablePath = null;

    // Whether to run in parallel; it is also switched on implicitly when more than one thread is requested.
    @Argument(fullName=PARALLEL_FULL_NAME, doc="run alignment data collection and  estimation in parallel", optional = true)
    private boolean runInParallel = false;

    // Requested number of threads; SYSTEM_SUGGESTED_THREAD_NUMBER (0) leaves the choice to the VM.
    @Argument(fullName=THREADS_FULL_NAME, minValue = SYSTEM_SUGGESTED_THREAD_NUMBER, doc="suggested number of parallel threads to perform the estimation, "
            + "the default 0 leave it up to the VM to decide. When set to more than 1, this will activate parallel in the absence of --parallel", optional = true)
    private int threads = SYSTEM_SUGGESTED_THREAD_NUMBER;

    // Target size, in base pairs, of the shards that are processed in parallel.
    @Argument(fullName=SHARD_SIZE_FULL_NAME, doc="when running in parallel this is the suggested shard size in base pairs. " +
            "The actual shard-size may vary to adapt to small contigs and the requested number of threads",
              minValue = MINIMUM_SHARD_SIZE, optional = true)
    private int shardSize = DEFAULT_SHARD_SIZE;

    // Target maximum number of cases kept per period and repeat-length combination.
    @Argument(fullName=DOWN_SAMPLE_SIZE_FULL_NAME, doc="Targeted maximum number of cases per combination period repeat count, " +
            "the larger the more precise but also the slower estimation.",
              minValue = MINIMUM_DOWN_SAMPLE_SIZE, optional = true)
    private int downsampleSize = DEFAULT_DOWN_SAMPLE_SIZE;

    // Mandatory: destination of the resulting parameter file.
    @Argument(fullName= StandardArgumentDefinitions.OUTPUT_LONG_NAME, shortName = StandardArgumentDefinitions.OUTPUT_SHORT_NAME, doc = "where to write the parameter output file.")
    private GATKPath output = null;

    // Optional per-site debugging output; when null no site details are written.
    @Argument(fullName= DEBUG_SITES_OUTPUT_FULL_NAME, doc = "table with information gather on the samples sites. Includes what sites were downsampled, disqualified or accepted for parameter estimation", optional = true)
    private String sitesOutput = null;

    // Testing aid: estimate parameters even when there are too few cases.
    @Argument(fullName= FORCE_ESTIMATION_FULL_NAME, doc = "for testing purpose only; force parameter estimation even with few datapoints available", optional = true)
    private boolean forceEstimation = false;

    // Sequence dictionary in use; set in onStartup() from the reads data-source and reset in traverse().
    private SAMSequenceDictionary dictionary;
    // Reader factory used to open the additional per-thread read sources in parallel mode.
    private SamReaderFactory factory;
    // Shared DRAGEN mapping-quality read transformer instance.
    public static final ReadTransformer EXTENDED_MQ_READ_TRANSFORMER = new DRAGENMappingQualityReadTransformer();

    /**
     * Declares that this tool always needs a reference.
     * @return always {@code true}.
     */
    @Override
    public boolean requiresReference() {
        return true;
    }

    /**
     * Declares that this tool always needs input reads.
     * @return always {@code true}.
     */
    @Override
    public boolean requiresReads() {
        return true;
    }

    /**
     * Validates the hyper-parameters, captures the reads sequence dictionary and reader factory and
     * resolves whether the tool is going to run in parallel, logging that decision.
     */
    @Override
    protected void onStartup() {
        super.onStartup();
        hyperParameters.validate();
        dictionary = directlyAccessEngineReadsDataSource().getSequenceDictionary();
        factory = makeSamReaderFactory();

        final boolean explicitlyParallel = runInParallel;
        if (explicitlyParallel && threads == 1) {
            logger.warn("parallel processing was requested but the number of threads was set to 1");
        }
        // asking for more than one thread implies parallel execution even without --parallel.
        runInParallel = explicitlyParallel || threads > 1;
        if (!runInParallel) {
            return;
        }

        final String parallelMessage = threads == 0
                ? "Running in parallel using the system suggested default thread count: " + Runtime.getRuntime().availableProcessors()
                : "Running in parallel using the requested number of threads: " + threads;
        logger.info(parallelMessage);
    }

    /**
     * Runs the whole calibration: collects the locus cases over the traversal intervals (sequentially or
     * in parallel), downsamples them, keeps the qualifying ones and finally prints the parameter file.
     */
    @Override
    public void traverse() {
        hyperParameters.validate();
        dictionary = getBestAvailableSequenceDictionary();

        final List<SAMReadGroupRecord> readGroupRecords;
        if (hasReads()) {
            readGroupRecords = getHeaderForReads().getReadGroups();
        } else {
            readGroupRecords = Collections.emptyList();
        }
        final List<String> readGroupIds = readGroupRecords.stream()
                .map(SAMReadGroupRecord::getId)
                .collect(Collectors.toList());
        final List<String> distinctSampleNames = readGroupRecords.stream()
                .map(SAMReadGroupRecord::getSample)
                .distinct()
                .collect(Collectors.toList());
        final String sampleNameOrNull = resolveSampleName(distinctSampleNames).orElse(null);

        try (final PrintWriter sitesOutputWriter = openSitesOutputWriter(sitesOutput);
             final STRTableFile strTable = STRTableFile.open(strTablePath)) {

            checkSequenceDictionaryCompatibility(dictionary, strTable.dictionary());
            final List<SimpleInterval> intervals = getTraversalIntervals();

            // more than one thread implies parallel execution.
            if (threads > 1) {
                runInParallel = true;
            }
            if (runInParallel && threads == 1) {
                logger.warn("parallel processing was requested but the number of threads was set to 1");
            }
            final StratifiedDragstrLocusCases allSites = runInParallel
                    ? collectCaseStatsParallel(intervals, shardSize, strTable)
                    : collectCaseStatsSequencial(intervals, strTable);
            logSiteCounts(allSites, "all loci/cases");

            final StratifiedDragstrLocusCases downSampledSites = downSample(allSites, strTable, sitesOutputWriter);
            logSiteCounts(downSampledSites, "all downsampled (kept) loci/cases");

            final StratifiedDragstrLocusCases finalSites = downSampledSites.qualifyingOnly(hyperParameters.minDepth, hyperParameters.minMQ, 0);
            logSiteCounts(finalSites, "all qualifying loci/cases");

            outputDownSampledSiteDetails(downSampledSites, sitesOutputWriter, hyperParameters.minDepth, hyperParameters.minMQ, 0);
            printOutput(finalSites, sampleNameOrNull, readGroupIds);
        }
    }

    /**
     * Writes the parameter file: estimated parameters when there are enough cases (or estimation is
     * forced), the default parameters otherwise. The file is annotated with the sample, read-groups,
     * provenance of the values and the command line.
     * @param finalSites the qualifying cases to use for the estimation.
     * @param sampleName the sample name, or {@code null} if unknown.
     * @param readGroups ids of the input read-groups, possibly empty.
     */
    private void printOutput(final StratifiedDragstrLocusCases finalSites, final String sampleName, final List<String> readGroups) {
        final boolean enoughCases = isThereEnoughCases(finalSites);
        final boolean usingDefaults = !(enoughCases || forceEstimation);

        final String provenance;
        if (usingDefaults) {
            provenance = "defaults";
        } else if (enoughCases) {
            provenance = "estimated";
        } else {
            provenance = "estimatedByForce";
        }
        final String sampleAnnotation = sampleName != null ? sampleName : "<unspecified>";
        final String readGroupsAnnotation = readGroups.isEmpty() ? "<unspecified>" : Utils.join(", ", readGroups);
        final Object[] annotations = {
                "sample", sampleAnnotation,
                "readGroups", readGroupsAnnotation,
                "estimatedOrDefaults", provenance,
                "commandLine", getCommandLine()
        };

        if (usingDefaults) {
            logger.warn("Not enough cases to estimate parameters, using defaults");
            DragstrParamUtils.print(DragstrParams.DEFAULT, output, annotations);
            return;
        }

        if (enoughCases) {
            logger.info("Estimating parameters using sampled down cases");
        } else {
            logger.warn("Forcing parameters estimation using sampled down cases as requested");
        }
        final DragstrParams estimate = estimateParams(finalSites);
        logger.info("Done with estimation, printing output");
        DragstrParamUtils.print(estimate, output, annotations);
    }

    /**
     * Resolves the single sample name the input alignments refer to.
     * @param sampleNames the distinct sample names found in the read-groups.
     * @return the only sample name, or empty if there is none (a warning is logged in that case).
     * @throws GATKException if there is more than one sample name.
     */
    private Optional<String> resolveSampleName(List<String> sampleNames) {
        if (sampleNames.size() > 1) {
            throw new GATKException("the input alignment(s) have more than one sample: " + String.join(", ", sampleNames));
        }
        final String onlyName = sampleNames.isEmpty() ? null : sampleNames.get(0);
        if (onlyName == null) {
            logger.warn("there is no sample id in the alignment header, assuming that all reads and read/groups make reference to the same anonymous sample");
            return Optional.empty();
        }
        return Optional.of(onlyName);
    }

    /**
     * Makes sure that the reference and the str-table sequence dictionaries can be used together.
     * @param reference the reference dictionary.
     * @param strTable the dictionary of the str-table.
     * @throws GATKException if the dictionaries are not compatible.
     */
    private void checkSequenceDictionaryCompatibility(final SAMSequenceDictionary reference, final SAMSequenceDictionary strTable) {
        final SequenceDictionaryUtils.SequenceDictionaryCompatibility outcome =
                SequenceDictionaryUtils.compareDictionaries(reference, strTable, false);
        switch (outcome) {
            case IDENTICAL:
            case SUPERSET:
            // the two below are probably never returned since we ask for no check on order,
            // but we accept them just in case as we don't care about the order.
            case NON_CANONICAL_HUMAN_ORDER:
            case OUT_OF_ORDER:
                return;
            default:
                throw new GATKException("the reference and str-table sequence dictionary are incompatible: " + outcome);
        }
    }

    /**
     * Opens the writer for the per-site debugging output.
     * @param sitesOutput where to write, or {@code null} to discard everything written.
     * @return never {@code null}.
     */
    @SuppressWarnings("deprecation")
    private PrintWriter openSitesOutputWriter(final String sitesOutput) {
        if (sitesOutput != null) {
            return new PrintWriter(BucketUtils.createFile(sitesOutput));
        } else {
            // no destination requested: swallow the output.
            return new PrintWriter(NullOutputStream.NULL_OUTPUT_STREAM);
        }
    }

    /**
     * Writes one line per downsampled case flagging whether it qualifies ("used") or not ("skipped").
     * Does nothing unless a sites output was requested.
     * @param finalSites the cases that remain after downsampling.
     * @param writer where to write the details.
     * @param minDepth minimum depth for a case to qualify.
     * @param samplingMinMQ minimum mapping quality for a case to qualify.
     * @param maxSup maximum number of supplementary alignments for a case to qualify.
     */
    private void outputDownSampledSiteDetails(final StratifiedDragstrLocusCases finalSites,
                                              final PrintWriter writer,
                                              final int minDepth,
                                              final int samplingMinMQ,
                                              final int maxSup) {
        if (sitesOutput == null) {
            return;
        }
        for (final DragstrLocusCases[] casesForPeriod : finalSites.perPeriodAndRepeat) {
            for (final DragstrLocusCases casesForRepeat : casesForPeriod) {
                for (final DragstrLocusCase locusCase : casesForRepeat) {
                    final String status;
                    if (locusCase.qualifies(minDepth, samplingMinMQ, maxSup)) {
                        status = "used";
                    } else {
                        status = "skipped";
                    }
                    outputSiteDetails(writer, locusCase, status);
                }
            }
        }
    }

    /**
     * Holds the minimum number of cases required for each period and repeat-length combination.
     * If there is a lack of data for any of these combinations, the default parameter
     * tables are used instead. The first index is the period and the second the repeat-length;
     * missing rows (periods) or columns (repeat-lengths) are interpreted as 0.
     */
    private static final int[][] MINIMUM_CASES_BY_PERIOD_AND_LENGTH =
            // @formatter:off ; prevents code reformatting by IntelliJ
            //                  if enabled:
            //                    Preferences > Editor > Code Style > Formatter Control
            // run-length:
            //  0,   1,   2,   3,   4,   5,   6,   7,   8,   9, 10+   // period
            {  {},
               {0, 200, 200, 200, 200, 200, 200, 200, 200, 200,   0}, // 1
               {0,   0, 200, 200, 200, 200,   0,   0,   0,   0,   0}, // 2
               {0,   0, 200, 200, 200,   0,   0,   0,   0,   0,   0}, // 3
               {0,   0, 200, 200,   0,   0,   0,   0,   0,   0,   0}, // 4
               {0,   0, 200,   0,   0,   0,   0,   0,   0,   0,   0}, // 5
               {0,   0, 200,   0,   0,   0,   0,   0,   0,   0,   0}, // 6
               {0,   0, 200,   0,   0,   0,   0,   0,   0,   0,   0}, // 7
               {0,   0, 200,   0,   0,   0,   0,   0,   0,   0,   0}, // 8
            };
            // the zeros to the right are not actually necessary; they are there to make it look more like a matrix.
            // @formatter:on

    /**
     * Checks that a minimum number of cases are available in key bins (period and repeat-length combos),
     * as listed in {@link #MINIMUM_CASES_BY_PERIOD_AND_LENGTH}.
     * <p>
     *     The result reflects data sufficiency only: it is {@code false} whenever at least one bin falls
     *     short, regardless of {@link #forceEstimation}. It is up to the caller to combine it with that flag
     *     to decide between forced estimation and defaults; returning {@code true} for a forced run would
     *     make a forced estimate indistinguishable from a regular one.
     * </p>
     * <p>
     *     Every bin that falls short is reported with a warning.
     * </p>
     * @param allSites the cases to check.
     * @return {@code true} iff every bin has at least its minimum required number of cases.
     */
    private boolean isThereEnoughCases(final StratifiedDragstrLocusCases allSites) {
        final int[][] minimumCases = MINIMUM_CASES_BY_PERIOD_AND_LENGTH;
        // rows/columns beyond the table bounds are interpreted as a minimum of 0, so they cannot fail.
        final int maxPeriod = Math.min(hyperParameters.maxPeriod, minimumCases.length - 1);
        final List<Tuple4<Integer, Integer, Integer, Integer>> failingCombos = new ArrayList<>(10);
        for (int period = 1; period <= maxPeriod; period++) {
            final int maxLength = Math.min(hyperParameters.maxRepeatLength, minimumCases[period].length - 1);
            for (int length = 1; length <= maxLength; length++) {
                final int available = allSites.get(period, length).size();
                final int required = minimumCases[period][length];
                if (available < required) {
                    failingCombos.add(new Tuple4<>(period, length, available, required));
                }
            }
        }
        if (failingCombos.isEmpty()) {
            return true;
        }

        if (forceEstimation) {
            logger.warn("there is not enough data to proceed to parameter empirical estimation " +
                    "but user requested to force it, so we go ahead");
        } else {
            logger.warn("there is not enough data to proceed to parameter empirical estimation, using defaults instead");
        }
        for (final Tuple4<Integer, Integer, Integer, Integer> failingCombo : failingCombos) {
            logger.warn(String.format("(P=%d, L=%d) count %d is less than minimum required %d ",
                    failingCombo._1(), failingCombo._2(), failingCombo._3(), failingCombo._4()));
        }
        return false;
    }

    /**
     * Performs the final estimation step, in parallel if so requested.
     * @param finalSites the sites to use for the estimation.
     * @return never {@code null}.
     */
    private DragstrParams estimateParams(final StratifiedDragstrLocusCases finalSites) {
        final DragstrParametersEstimator estimator = new DragstrParametersEstimator(hyperParameters);
        if (runInParallel) {
            return Utils.runInParallel(threads, () -> estimator.estimate(finalSites));
        }
        return estimator.estimate(finalSites);
    }

    /**
     * Downsamples sites so that at most as many as {@link #downsampleSize} cases remain for each period and
     * repeat-length combination.
     * @param allSites the sites to downsample.
     * @param strTable the table file that contains the decimation table used to generate those sites.
     * @param sitesOutputWriter per-site information output for debugging purposes.
     * @return never {@code null}.
     */
    private StratifiedDragstrLocusCases downSample(final StratifiedDragstrLocusCases allSites, final STRTableFile strTable,
                                                   final PrintWriter sitesOutputWriter) {
        final STRDecimationTable decimationTable = strTable.decimationTable();
        final int maxPeriod = hyperParameters.maxPeriod;
        final int maxRepeatLength = hyperParameters.maxRepeatLength;

        // enumerate every (period, repeat-length) combination; each is downsampled independently.
        final List<PeriodAndRepeatLength> combos = new ArrayList<>(maxPeriod * maxRepeatLength);
        for (int period = 1; period <= maxPeriod; period++) {
            for (int repeatLength = 1; repeatLength <= maxRepeatLength; repeatLength++) {
                combos.add(PeriodAndRepeatLength.of(period, repeatLength));
            }
        }

        final Stream<PeriodAndRepeatLength> comboStream;
        if (runInParallel) {
            comboStream = combos.parallelStream();
        } else {
            comboStream = combos.stream();
        }
        final Stream<DragstrLocusCase> keptCases = comboStream.flatMap(combo ->
                downSample(allSites.perPeriodAndRepeat[combo.period - 1][combo.repeatLength - 1],
                           decimationTable.decimationBit(combo.period, combo.repeatLength),
                           downsampleSize,
                           sitesOutputWriter).stream());

        if (!runInParallel) {
            return keptCases.collect(DragstrLocusCaseStratificator.make(maxPeriod, maxRepeatLength));
        }
        return Utils.runInParallel(threads,
                () -> keptCases.collect(DragstrLocusCaseStratificator.make(maxPeriod, maxRepeatLength)));
    }

    /**
     * Pre-calculated decimation masks used depending on the final decimation bit/level.
     * The mask at position {@code i} has every bit set except for bit {@code i}.
     */
    private static final long[] DECIMATION_MASKS_BY_BIT = new long[Long.SIZE];

    // Code to populate DECIMATION_MASKS_BY_BIT.
    static {
        for (int bit = 0; bit < Long.SIZE; bit++) {
            DECIMATION_MASKS_BY_BIT[bit] = ~(1L << bit);
        }
    }

    /**
     * Decimates the collection of locus/cases to the downsample size provided or smaller.
     * <p>
     *     Notice that if we need to downsample (the input size is larger than the downsample size provided)
     *     we take care of not counting those cases that have zero depth toward that limit.
     *     This is due to the apparent behaviour in DRAGEN where those are "sort-of" filtered before
     *     decimation as far as meeting the final downsample size limit is concerned.
     * </p>
     * <p>
     *     They usually would be skipped eventually in post-downsampling filtering, but if we did not
     *     discount them here we would end up downsampling some period, repeat-length combinations too much
     *     as compared to DRAGEN.
     * </p>
     * <p>
     *     This behavior in DRAGEN may well change in future releases.
     * </p>
     * @param in input collection of cases to downsample.
     * @param minDecimationBit The start decimation bit. Usually the input cases collection won't contain any
     *                         cases with lower bit set (already decimated).
     * @param downsampleSize the target size.
     * @param sitesOutputWriter where to report the cases that are dropped; only used if a sites output was requested.
     * @return never {@code null}. The input itself if it already meets the size limit. Otherwise the cases
     *  that are kept: none with zero depth, and decimation proceeds bit by bit until at most
     *  {@code downsampleSize} cases remain or there are no more bits to decimate by. It could be empty.
     */
    private DragstrLocusCases downSample(final DragstrLocusCases in, final int minDecimationBit, final int downsampleSize, final PrintWriter sitesOutputWriter) {
        final int inSize = in.size();
        if (inSize <= downsampleSize) { // we already satisfy the imposed size limit so we do nothing.
            return in;
        }

        int zeroDepth = 0;
        // Indexed by absolute bit position (minDecimationBit .. Long.SIZE - 1), so it must span all the
        // bits of a mask; a shorter array would overflow when the high bits are reached.
        final int[] countByFirstDecimatingBit = new int[Long.SIZE];
        for (final DragstrLocusCase caze : in) {
            if (caze.getDepth() <= 0) { // we discount cases with zero depth as these are going to be skipped eventually.
                zeroDepth++;
                continue;
            }
            final long mask = caze.getLocus().getMask();
            // find the first bit, from the minimum up, that would decimate this case.
            for (int j = minDecimationBit; mask != 0 && j < Long.SIZE; j++) {
                if ((mask & DECIMATION_MASKS_BY_BIT[j]) != mask) {
                    countByFirstDecimatingBit[j]++;
                    break;
                }
            }
        }

        final IntList progressiveSizes = new IntArrayList(Long.SIZE + 1);
        progressiveSizes.add(inSize);
        int finalSize = inSize - zeroDepth;
        progressiveSizes.add(finalSize);
        // accumulates the bits we decimate by; any case with one of these bits set is dropped.
        long filterMask = 0;
        for (int j = minDecimationBit; finalSize > downsampleSize && j < Long.SIZE; j++) {
            finalSize -= countByFirstDecimatingBit[j];
            filterMask |= ~DECIMATION_MASKS_BY_BIT[j];
            progressiveSizes.add(finalSize);
        }

        final DragstrLocusCases kept = new DragstrLocusCases(finalSize, in.getPeriod(), in.getRepeatLength());
        final DragstrLocusCases dropped = new DragstrLocusCases(in.size() - finalSize, in.getPeriod(), in.getRepeatLength());
        for (final DragstrLocusCase caze : in) {
            final long mask = caze.getLocus().getMask();
            if ((mask & filterMask) == 0 && caze.getDepth() > 0) {
                kept.add(caze);
            } else {
                dropped.add(caze);
            }
        }

        // Debug-log message format explained:
        // period repeat-length [x0, x00, x1, x2, x3 ... xN]
        // where x0 is the input size.
        //       x00 = x0 - #zero depth cases
        //       x1 = x00 - #first round of decimation
        //       x2 = x1  - #second round of decimation.
        //       ...
        //       xN = final size <= downsampleSize
        logger.debug(() -> "" + in.getPeriod() + " "  + in.getRepeatLength() + " "
                + Arrays.toString(progressiveSizes.toArray()));

        // we output info about the sites that are dropped; this method may be invoked from several
        // threads sharing the same writer, hence the synchronization.
        if (sitesOutput != null && dropped.size() > 0) {
            synchronized (this) {
                for (final DragstrLocusCase caze : dropped) {
                    outputSiteDetails(sitesOutputWriter, caze, "downsampled-out");
                }
            }
        }
        return kept;
    }

    /**
     * Debug-logs the case counts as a matrix where columns are periods and rows are
     * repeat lengths in repeat units. Does nothing unless debug logging is enabled.
     * @param cases the cases whose counts are to be logged.
     * @param title the title of the debug message.
     */
    private void logSiteCounts(final StratifiedDragstrLocusCases cases, final String title) {
        if (!logger.isDebugEnabled()) { // this method is all about debug logging, so we save time if DEBUG is off.
            return;
        }
        logger.debug(title);
        final int maxPeriod = hyperParameters.maxPeriod;
        final int maxRepeatLength = hyperParameters.maxRepeatLength;

        // each column is wide enough for its largest count, with a minimum of 7 characters.
        final int[] columnWidths = new int[Math.max(0, maxPeriod)];
        for (int period = 1; period <= maxPeriod; period++) {
            int largest = 0;
            for (int repeat = 1; repeat <= maxRepeatLength; repeat++) {
                largest = Math.max(largest, cases.get(period, repeat).size());
            }
            columnWidths[period - 1] = (int) Math.max(7, Math.ceil(Math.log10(largest)) + 1);
        }

        final StringBuilder header = new StringBuilder("      ");
        for (int period = 1; period <= maxPeriod; period++) {
            header.append(String.format("%-" + columnWidths[period - 1] + "s", period));
        }
        logger.debug(header.toString());

        for (int repeat = 1; repeat <= maxRepeatLength; repeat++) {
            final StringBuilder row = new StringBuilder(String.format("%-4s", repeat)).append("  ");
            for (int period = 1; period <= maxPeriod; period++) {
                row.append(String.format("%-" + columnWidths[period - 1] + "s", cases.get(period, repeat).size()));
            }
            logger.debug(row.toString());
        }
    }

    /**
     * Collects the locus cases over the given intervals one interval at a time on the calling thread,
     * reading from the engine's own reads data-source.
     * @param intervals the intervals to traverse.
     * @param strTable the str-table file to read the loci from.
     * @return never {@code null}.
     * @throws GATKException if there is a problem reading from the str-table file; the underlying
     *  {@link IOException} is kept as its cause.
     */
    private StratifiedDragstrLocusCases collectCaseStatsSequencial(final List<SimpleInterval> intervals, final STRTableFile strTable) {
        final StratifiedDragstrLocusCases result = StratifiedDragstrLocusCases.make(hyperParameters.maxPeriod, hyperParameters.maxRepeatLength);
        final ReadsDataSource dataSource = directlyAccessEngineReadsDataSource();
        for (final SimpleInterval interval : intervals) {
            try (final BinaryTableReader<DragstrLocus> reader = strTable.locusReader(interval)) {
                streamShardCasesStats(interval, readStream(dataSource, interval), reader.stream())
                        .forEach(caze -> {
                            progressMeter.update(caze.getLocation(dictionary));
                            result.add(caze);
                        });
            } catch (final IOException ex) {
                // keep the cause so that the actual I/O problem is not lost.
                throw new GATKException("problems accessing str-table-file at " + strTablePath, ex);
            }
        }
        return result;
    }

    /**
     * Collects the locus cases over the given intervals in parallel: the intervals are split into shards,
     * each shard is processed into its own {@link StratifiedDragstrLocusCases} and these are then merged.
     * @param intervals the intervals to traverse.
     * @param shardSize the target shard size in base-pairs.
     * @param strTable the str-table file to read the loci from.
     * @return never {@code null}; an empty collection if there were no shards to process.
     * @throws GATKException if there is a problem reading from the str-table file.
     */
    @SuppressWarnings("try") // silences intended use of unreferenced auto-closable within try-resource.
    private StratifiedDragstrLocusCases collectCaseStatsParallel(final List<SimpleInterval> intervals, final int shardSize, final STRTableFile strTable) {

        //TODO: instead of the dictionary this should take on the traversal intervals.
        //TODO: currently, if the user specifies intervals the progress-meter will show that amount of bases at the beginning of the reference
        //  instead.
        final AbsoluteCoordinates absoluteCoordinates = AbsoluteCoordinates.of(dictionary);

        final List<SimpleInterval> shards = shardIntervals(intervals, shardSize);

        // Every thread opens its own reads source lazily the first time it needs one; all of them are
        // registered in a synchronized collection (Vector) so that they can be disposed of together below.
        final Collection<ReadsPathDataSource> readSources = new Vector<>(threads);
        final ThreadLocal<ReadsPathDataSource> threadReadSource = ThreadLocal.withInitial(
                () -> {
                    final ReadsPathDataSource result = new ReadsPathDataSource(readArguments.getReadPaths(), factory);
                    readSources.add(result);
                    return result;
                });

        // NOTE(review): presumably the closer closes every source registered in readSources when the
        // try block exits -- confirm against AutoCloseableCollection.
        try (@SuppressWarnings("unused") final AutoCloseableCollection<?> readSourceCloser = new AutoCloseableCollection<>(readSources)) {
            final AtomicLong numberBasesProcessed = new AtomicLong(0);
            // we don't ask for more threads than there are shards to process.
            return Utils.runInParallel(Math.min(threads, shards.size()), () ->
                    StreamSupport.stream(new InterleavingListSpliterator<>(shards), true)
                            .map(shard -> {
                                try (final BinaryTableReader<DragstrLocus> lociReader = strTable.locusReader(shard)) {
                                    final ReadsPathDataSource readsSource = threadReadSource.get();
                                    final StratifiedDragstrLocusCases result = streamShardCasesStats(shard, readStream(readsSource, shard), lociReader.stream())
                                            .collect(DragstrLocusCaseStratificator.make(hyperParameters.maxPeriod, hyperParameters.maxRepeatLength));
                                    final int resultSize = result.size();
                                    // the lock makes the count update and the progress-meter update a single step,
                                    // so that the meter is only updated by one thread at a time.
                                    synchronized (numberBasesProcessed) {
                                        final long processed = numberBasesProcessed.updateAndGet(l -> l + shard.size());
                                        progressMeter.update(absoluteCoordinates.toSimpleInterval(processed, 1), resultSize);
                                    }
                                    return result;
                                } catch (final IOException ex) {
                                    throw new GATKException("problems accessing the str-table-file contents at " + strTablePath, ex);
                                }
                            })
                            .reduce(StratifiedDragstrLocusCases::merge)
                            .orElseGet(() -> new StratifiedDragstrLocusCases(
                                    hyperParameters.maxPeriod,
                                    hyperParameters.maxRepeatLength)));
        }
    }

    /**
     * Shards the traversal intervals based on the requested target shard size.
     * <p>
     *     Intervals are sorted and merged first if needed. An interval is only split while more than
     *     1.5 times the target shard size is left in it, so the last shard of an interval may be
     *     up to that long.
     * </p>
     * @param raw the unprocessed traversal intervals.
     * @param shardSize the target shard size in base-pairs.
     * @return never {@code null}, but perhaps empty if the input was empty.
     */
    private List<SimpleInterval> shardIntervals(final List<SimpleInterval> raw, final int shardSize) {
        final List<SimpleInterval> merged = sortAndMergeOverlappingIntervals(raw, dictionary);
        long totalBases = 0;
        for (final SimpleInterval interval : merged) {
            totalBases += interval.size();
        }
        final List<SimpleInterval> output = new ArrayList<>((int) (merged.size() + totalBases / shardSize));
        // if less than 1.5 x the desired shard size is left in the current interval we don't split any further:
        final int splitThreshold = (int) Math.round(shardSize * 1.5);
        for (final SimpleInterval interval : merged) {
            if (interval.size() < splitThreshold) {
                output.add(interval);
                continue;
            }
            final String contig = interval.getContig();
            final int lastBase = interval.getEnd();
            final int splitLimit = lastBase - splitThreshold + 1;
            int shardStart = interval.getStart();
            for (; shardStart < splitLimit; shardStart += shardSize) {
                output.add(new SimpleInterval(contig, shardStart, shardStart + shardSize - 1));
            }
            // whatever is left goes into a last, possibly longer, shard.
            if (shardStart <= lastBase) {
                output.add(new SimpleInterval(contig, shardStart, lastBase));
            }
        }
        return output;
    }

    /**
     * Makes sure that the traversal intervals are sorted and free of overlaps.
     * <p>
     *     If the input already complies it is returned as is. Otherwise the work is delegated to
     *     {@link IntervalUtils#sortAndMergeIntervals} (using {@link IntervalMergingRule#ALL}) and the resulting
     *     per-contig lists are concatenated following the contig order in the dictionary.
     * </p>
     */
    private List<SimpleInterval> sortAndMergeOverlappingIntervals(final List<SimpleInterval> input, final SAMSequenceDictionary dictionary) {
        if (isSortedAndHasNoOverlap(input, dictionary)) {
            return input;
        }
        final Map<String, List<SimpleInterval>> mergedByContig = IntervalUtils.sortAndMergeIntervals(input, dictionary, IntervalMergingRule.ALL);
        final List<String> contigNames = new ArrayList<>(mergedByContig.keySet());
        contigNames.sort(Comparator.comparingInt(name -> dictionary.getSequence(name).getSequenceIndex()));
        final List<SimpleInterval> result = new ArrayList<>();
        for (final String contigName : contigNames) {
            result.addAll(mergedByContig.get(contigName));
        }
        return result;
    }

    /**
     * Checks whether the input intervals are already in order and do not overlap each other.
     * <p>
     *     Contigs must appear in strictly increasing dictionary index order and, within a contig, every
     *     interval must start after the end of the previous one.
     * </p>
     * @param input the intervals to check.
     * @param dictionary the dictionary that determines the contig order.
     * @return {@code true} iff the input is sorted and overlap free; always {@code true} for an empty input.
     */
    private boolean isSortedAndHasNoOverlap(final List<SimpleInterval> input, final SAMSequenceDictionary dictionary) {
        String lastContig = null;
        int lastContigIndex = -1;
        int lastEnd = 0;
        for (final SimpleInterval current : input) {
            final String contig = current.getContig();
            if (contig.equals(lastContig)) {
                // same contig: must start beyond the previous interval.
                if (current.getStart() <= lastEnd) {
                    return false;
                }
            } else {
                // contig change: must move forward in the dictionary.
                final int contigIndex = dictionary.getSequenceIndex(contig);
                if (contigIndex <= lastContigIndex) {
                    return false;
                }
                lastContig = contig;
                lastContigIndex = contigIndex;
            }
            lastEnd = current.getEnd();
        }
        return true;
    }

    /**
     * Stream collector that gathers individual {@link DragstrLocusCase} instances into a
     * {@link StratifiedDragstrLocusCases} created with the given maximum period and repeat count.
     */
    private static class DragstrLocusCaseStratificator implements Collector<DragstrLocusCase, StratifiedDragstrLocusCases, StratifiedDragstrLocusCases> {

        private final int maxPeriod;
        private final int maxRepeats;

        private DragstrLocusCaseStratificator(final int maxPeriod, final int maxRepeats) {
            this.maxPeriod = maxPeriod;
            this.maxRepeats = maxRepeats;
        }

        /**
         * Factory method.
         * @param maxPeriod maximum period passed to every container this collector creates.
         * @param maxRepeats maximum repeat count passed to every container this collector creates.
         * @return never {@code null}.
         */
        private static DragstrLocusCaseStratificator make(final int maxPeriod, final int maxRepeats) {
            return new DragstrLocusCaseStratificator(maxPeriod, maxRepeats);
        }

        /** Each partial result starts as an empty container with this collector's dimensions. */
        @Override
        public Supplier<StratifiedDragstrLocusCases> supplier() {
            final int period = maxPeriod;
            final int repeats = maxRepeats;
            return () -> new StratifiedDragstrLocusCases(period, repeats);
        }

        @Override
        public BiConsumer<StratifiedDragstrLocusCases, DragstrLocusCase> accumulator() {
            return StratifiedDragstrLocusCases::add;
        }

        @Override
        public BinaryOperator<StratifiedDragstrLocusCases> combiner() {
            return StratifiedDragstrLocusCases::addAll;
        }

        /** The accumulation container is the final result. */
        @Override
        public Function<StratifiedDragstrLocusCases, StratifiedDragstrLocusCases> finisher() {
            return Function.identity();
        }

        @Override
        public Set<Characteristics> characteristics() {
            return EnumSet.of(Characteristics.IDENTITY_FINISH, Characteristics.UNORDERED);
        }
    }

    /**
     * Stream collector that condenses the read-sets overlapping a STR locus into a single {@link DragstrLocusCase}.
     * <p>
     *     Only read-sets that span the whole padded STR are taken into account. For those it tallies
     *     the number of reads ({@code n}), the number of insertion/deletion CIGAR elements found on the STR
     *     weighted by read-set size ({@code k}), the number of supplementary alignments ({@code nSup}) and the
     *     minimum mapping quality ({@code minMQ}).
     * </p>
     */
    private static class DragstrLocusCaseCollector implements Collector<EquivalentReadSet, DragstrLocusCaseCollector, DragstrLocusCase> {

        private final DragstrLocus locus;
        private final long strStart;
        private final long strEnd;
        private final long strEndPlusOne;
        private final long paddedStrStart;
        private final long paddedStrEnd;

        // running tallies; see the class documentation for their meaning.
        private int n = 0;
        private int k = 0;
        private int minMQ = SAMRecord.UNKNOWN_MAPPING_QUALITY;
        private int nSup = 0;

        private DragstrLocusCaseCollector(final DragstrLocus locus, final long strStart,
                                          final long strEnd, final long paddedStrStart, final long paddedStrEnd) {
            this.locus = locus;
            this.strStart = strStart;
            this.strEnd = strEnd;
            this.strEndPlusOne = strEnd + 1;
            this.paddedStrStart = paddedStrStart;
            this.paddedStrEnd = paddedStrEnd;
        }

        /**
         * Creates a collector for a locus.
         * @param locus the target STR locus.
         * @param padding number of bases, on each side of the STR, that reads must cover as well; 0 or greater.
         * @param contigLength the length of the enclosing contig; the padded STR is clipped to {@code [1, contigLength]}.
         * @return never {@code null}.
         */
        public static DragstrLocusCaseCollector create(final DragstrLocus locus, final int padding, final long contigLength) {
            Utils.nonNull(locus);
            Utils.validateArg(padding >= 0, "padding must be 0 or greater");
            Utils.validateArg(contigLength >= 1, "contig length must be strictly positive");
            final long start = locus.getStart();
            final long end = locus.getEnd();
            final long paddedStart = Math.max(1, start - padding);
            final long paddedEnd = Math.min(contigLength, end + padding);
            return new DragstrLocusCaseCollector(locus, start, end, paddedStart, paddedEnd);
        }

        @Override
        public Supplier<DragstrLocusCaseCollector> supplier() {
            return () -> new DragstrLocusCaseCollector(locus, strStart, strEnd, paddedStrStart, paddedStrEnd);
        }

        @Override
        public BiConsumer<DragstrLocusCaseCollector, EquivalentReadSet> accumulator() {
            return (collector, eset) -> collector.collect(eset);
        }

        /**
         * Adds the relevant stats of a read-set to this collector based on its overlap
         * with the STR and the presence of indel events.
         *
         * Read-sets that do not span the whole padded STR are ignored.
         * Assumes that the reads are mapped to the same contig as the locus, so we don't test
         * for that.
         * @param eset the read-set to collect.
         */
        private void collect(final EquivalentReadSet eset) {
            final int readStart = eset.getStart();
            final int readEnd = eset.getEnd();
            final int size = eset.size();
            if (readStart > paddedStrStart || readEnd < paddedStrEnd) {
                return; // does not span the padded STR.
            }
            if (eset.isSupplementaryAlignment()) {
                nSup += size;
            }
            minMQ = Math.min(minMQ, eset.getMappingQuality());
            int refPos = readStart;
            for (final CigarElement element : eset.getCigar()) {
                final CigarOperator op = element.getOperator();
                final int length = element.getLength();
                final boolean insertionOnStr = op == CigarOperator.I
                        && refPos >= strStart && refPos <= strEndPlusOne;
                final boolean deletionOnStr = op == CigarOperator.D
                        && refPos + length - 1 >= strStart && refPos <= strEnd;
                if (insertionOnStr || deletionOnStr) {
                    k += size;
                }
                // update refPos and quick end if we have gone beyond the end of the STR.
                if (op.consumesReferenceBases()) {
                    refPos += length;
                }
                if (refPos > strEndPlusOne) {
                    break;
                }
            }
            n += size;
        }

        /**
         * Composes a new collector, for the same locus, holding the combined tallies of this and another collector.
         */
        private DragstrLocusCaseCollector combineWith(final DragstrLocusCaseCollector other) {
            Utils.validateArg(other.locus == this.locus, "collectors at different loci cannot be convined");
            final DragstrLocusCaseCollector combined =
                    new DragstrLocusCaseCollector(locus, strStart, strEnd, paddedStrStart, paddedStrEnd);
            combined.n = this.n + other.n;
            combined.k = this.k + other.k;
            combined.nSup = this.nSup + other.nSup;
            combined.minMQ = Math.min(this.minMQ, other.minMQ);
            return combined;
        }

        /** Wraps up the current tallies into a locus case. */
        private DragstrLocusCase finish() {
            return DragstrLocusCase.create(locus, n, k, minMQ, nSup);
        }

        @Override
        public BinaryOperator<DragstrLocusCaseCollector> combiner() {
            return (left, right) -> left.combineWith(right);
        }

        @Override
        public Function<DragstrLocusCaseCollector, DragstrLocusCase> finisher() {
            return collector -> collector.finish();
        }

        @Override
        public Set<Characteristics> characteristics() {
            return Collections.emptySet();
        }
    }

    /**
     * Generates a stream of locus cases for an interval/shard.
     * <p>
     *     The returned stream in turn feeds on two streams: (a) stream of reads for the interval and (b)
     *     the str-table entry loci for that interval.
     * </p>
     * <p>
     *     These are traversed side by side, in position order within the shard, so that for every given
     *     str-table entry we get all the relevant reads that overlap the STR. The returned stream is sequential
     *     and produces exactly one case per locus in the input loci stream.
     * </p>
     * @param shard the target shard.
     * @param reads a stream on the reads in the input shard.
     * @param loci a stream on the loci in the input shard.
     * @return never {@code null}, perhaps an empty stream.
     */
    private Stream<DragstrLocusCase> streamShardCasesStats(final SimpleInterval shard, final Stream<GATKRead> reads, final Stream<DragstrLocus> loci) {
        final int contigLength = dictionary.getSequence(shard.getContig()).getSequenceLength();

        return StreamSupport.stream(new Spliterator<DragstrLocusCase>() {

            private final Spliterator<GATKRead> readSpliterator = reads.spliterator();
            private final Spliterator<DragstrLocus> lociSpliterator = loci.spliterator();
            // holds the reads pulled so far that may still overlap the current or a later locus.
            private final ShardReadBuffer readBuffer = new ShardReadBuffer();

            private GATKRead read;
            private DragstrLocus locus;

            /**
             * Move forward in the reads stream.
             * <p>
             * If it returns {@code true} the read to process is placed in {@link #read};
             * </p>
             *
             * @return {@code true} iff there is one more read to process.
             */
            private boolean advanceRead() {
                return readSpliterator.tryAdvance(read -> this.read = read);
            }

            /**
             * Move forward in the locus stream.
             * <p>
             * If it returns {@code true} the locus to process is placed in {@link #locus};
             * </p>
             *
             * @return {@code true} iff there is one more locus to process.
             */
            private boolean advanceLocus() {
                return lociSpliterator.tryAdvance(locus -> this.locus = locus);
            }

            /**
             * Emits the case for the next locus, if any, after pulling into the buffer every read
             * up to (and including) the first one that starts downstream from that locus.
             */
            @Override
            public boolean tryAdvance(final Consumer<? super DragstrLocusCase> action) {
                if (advanceLocus()) { // if true, sets 'locus' to the next in the stream.
                    readBuffer.removeUpstreamFrom((int) locus.getStart()); // flush the buffer from up-stream reads that we won't need again.
                    // We keep reading reads into the buffer until we reach the first downstream
                    // from the current subject.
                    while (advanceRead()) {  // if true sets 'read' to the next in the stream.
                        readBuffer.add(read.getAssignedStart(), read.getEnd(), read);
                        if (read.getAssignedStart() > locus.getEnd()) {
                            break;
                        }
                    }
                    // Now we compose the case given the locus and all overlapping reads.
                    final List<EquivalentReadSet> reads = readBuffer.overlapping((int) locus.getStart(), (int) locus.getEnd());
                    final DragstrLocusCase newCase = composeDragstrLocusCase(locus, reads, contigLength);
                    action.accept(newCase);
                    return true;
                } else { // no more loci in the stream, we are finished.
                    return false;
                }
            }

            /** This spliterator cannot be partitioned. */
            @Override
            public Spliterator<DragstrLocusCase> trySplit() {
                return null;
            }

            /**
             * The number of loci left is not known in advance; as per the {@link Spliterator#estimateSize()}
             * contract an unknown size is reported as {@link Long#MAX_VALUE} rather than 0.
             */
            @Override
            public long estimateSize() {
                return Long.MAX_VALUE;
            }

            /** No characteristics are declared. */
            @Override
            public int characteristics() {
                return 0;
            }
        }, false);
    }

    /**
     * Prints a single tab-separated line with the details of a site.
     * <p>
     *     Columns are: {@code chromosome-index:start-minus-one}, period, repeats, depth, indels,
     *     minimum mapping quality, number of supplementary alignments and the fate label.
     * </p>
     * @param writer where to print the line.
     * @param caze the case to describe.
     * @param fate free text label written in the last column.
     */
    private static void outputSiteDetails(final PrintWriter writer, final DragstrLocusCase caze, final String fate) {
        final DragstrLocus locus = caze.getLocus();
        final String site = locus.getChromosomeIndex() + ":" + (locus.getStart() - 1);
        writer.println(Utils.join("\t",
                site,
                locus.getPeriod(),
                locus.getRepeats(),
                caze.getDepth(),
                caze.getIndels(),
                caze.getMinMQ(),
                caze.getNSup(),
                fate));
    }

    /**
     * Composes the stream of reads to analyze from a source.
     * <p>
     *     Reads that carry any of the discard flags, or whose assigned start is beyond their end, are
     *     filtered out. The remaining reads are passed through {@code EXTENDED_MQ_READ_TRANSFORMER}.
     * </p>
     * @param source where to pull the reads from.
     * @param interval the interval to query, or {@code null} to traverse the whole source.
     * @return never {@code null}.
     */
    private Stream<org.broadinstitute.hellbender.utils.read.GATKRead> readStream(final ReadsDataSource source, final SimpleInterval interval) {
        final Stream<org.broadinstitute.hellbender.utils.read.GATKRead> unfiltered;
        if (interval != null) {
            unfiltered = Utils.stream(source.query(interval));
        } else {
            unfiltered = Utils.stream(source);
        }
        return unfiltered
                .filter(read -> (read.getFlags() & DISCARD_FLAG_VALUE) == 0)
                .filter(read -> read.getAssignedStart() <= read.getEnd())
                .map(EXTENDED_MQ_READ_TRANSFORMER);
    }

    // Reads that carry any of these flags are discarded from the analyses.
    private static final EnumSet<SAMFlag> DISCARD_FLAGS = EnumSet.of(
            SAMFlag.READ_UNMAPPED,
            SAMFlag.SECONDARY_ALIGNMENT,
            SAMFlag.READ_FAILS_VENDOR_QUALITY_CHECK);
    // Bit mask that combines all the discard flags so that they can be tested at once.
    private static final int DISCARD_FLAG_VALUE = DISCARD_FLAGS.stream()
            .mapToInt(SAMFlag::intValue)
            .reduce(0, (mask, bit) -> mask | bit);

    /**
     * Composes the case for a locus out of the read-sets that overlap it.
     * @param locus the target locus.
     * @param rawReads the read-sets overlapping the locus.
     * @param contigLength the length of the contig the locus is in.
     * @return the result of collecting the read-sets with a {@link DragstrLocusCaseCollector} for the locus.
     */
    private DragstrLocusCase composeDragstrLocusCase(final DragstrLocus locus, final List<EquivalentReadSet> rawReads, final long contigLength) {
        final DragstrLocusCaseCollector collector =
                DragstrLocusCaseCollector.create(locus, hyperParameters.strPadding, contigLength);
        return rawReads.stream().collect(collector);
    }

    /**
     * Set of reads that, for the intents and purposes of this model, are equivalent assuming that they are mapped
     * on the same location; we don't check for that.
     * <p>
     *     Two reads are equivalent when they have the same supplementary-alignment status, mapping quality
     *     and CIGAR. The set keeps just one example read and a count.
     * </p>
     * <p>
     *     Note that {@link #hashCode()} is content based whereas {@code equals} is not overridden;
     *     use {@link #belongs(GATKRead)} to test for equivalence.
     * </p>
     */
    private static class EquivalentReadSet {
        // representative read; every other read in the set is equivalent to it.
        private final GATKRead example;
        // number of reads in the set.
        private int size;

        private EquivalentReadSet(final GATKRead read) {
            example = read;
            size  = 1;
        }

        /**
         * Creates a set that contains just one read.
         * @param read the only member, and the example, of the new set.
         * @return never {@code null}.
         */
        public static EquivalentReadSet of(final GATKRead read) {
            Utils.nonNull(read);
            return new EquivalentReadSet(read);
        }

        /**
         * Checks whether a read is equivalent to the ones in this set.
         * @param read the read to test.
         * @return {@code true} iff it has the same supplementary status, mapping quality and CIGAR as this
         *   set's example read.
         */
        public boolean belongs(final GATKRead read) {
            return (read.isSupplementaryAlignment() == example.isSupplementaryAlignment()
                    && read.getMappingQuality() == example.getMappingQuality()
                    && read.getCigar().equals(example.getCigar()));
        }

        /**
         * Hash code of the set that a read would belong to.
         * <p>
         *     The supplementary status and the mapping quality are packed into a single number so that, for a
         *     given CIGAR, different combinations never yield the same hash. Previously these two terms were
         *     simply added together (times 31), and as {@link Boolean#hashCode(boolean)} values differ by 6,
         *     a supplementary read with mapping quality {@code m + 6} collided with a non-supplementary
         *     read with mapping quality {@code m}.
         * </p>
         * @param read the target read.
         * @return any integer; equivalent reads always result in the same value.
         */
        public static int hashCode(final GATKRead read) {
            final int supplementaryBit = read.isSupplementaryAlignment() ? 1 : 0;
            return 31 * (2 * read.getMappingQuality() + supplementaryBit) + read.getCigar().hashCode();
        }

        @Override
        public int hashCode() {
            return hashCode(example);
        }

        /**
         * Increases the number of reads in the set.
         * @param inc the number of reads to add.
         */
        public void increase(final int inc) {
            size += inc;
        }

        public int getStart() {
            return example.getStart();
        }

        public int getEnd() {
            return example.getEnd();
        }

        public boolean isSupplementaryAlignment() {
            return example.isSupplementaryAlignment();
        }

        /**
         * @return the number of reads in the set.
         */
        public int size() {
            return size;
        }

        public int getMappingQuality() {
            return example.getMappingQuality();
        }

        public Iterable<? extends CigarElement> getCigar() {
            return example.getCigar();
        }
    }

    /**
     * Simple read-buffer implementation.
     * <p>
     *     Reads are stored in an interval tree keyed by their start and end positions. Each node holds the
     *     read-sets at that exact location in a map indexed by an integer key derived from
     *     {@link EquivalentReadSet#hashCode(GATKRead)}.
     * </p>
     */
    private static class ShardReadBuffer extends IntervalTree<Int2ObjectMap<EquivalentReadSet>> {

        /**
         * Merges the content of two maps of read-sets mapped on the same location.
         * <p>
         *     Since the map key is just a hash, different kinds of reads may share it. Hence a match is
         *     always confirmed using {@link EquivalentReadSet#belongs}; upon a collision we probe the following
         *     consecutive keys until we find the equivalent set or a free key. This is consistent as entries are
         *     never removed individually from these maps.
         * </p>
         * @param left one of the maps to merge.
         * @param right the other map to merge.
         * @return the merged map, perhaps one of the inputs modified in place.
         */
        private static Int2ObjectMap<EquivalentReadSet> mergeEquivalentReadSets(final Int2ObjectMap<EquivalentReadSet> left,
                                                                                final Int2ObjectMap<EquivalentReadSet> right) {
            // receiver is the map that will collect the output, perhaps one of the inputs.
            // donor is the other map.
            // 1 size maps may be unmodifiable singletons so they only can be donors.
            final Int2ObjectMap<EquivalentReadSet> receiver, donor;
            if (left.size() > 1) {
                receiver = left; donor = right;
            } else if (right.size() > 1) {
                receiver = right; donor = left;
            } else {
                receiver = new Int2ObjectOpenHashMap<>(left);
                donor = right;
            }
            for (final EquivalentReadSet incoming : donor.values()) {
                int key = incoming.hashCode();
                EquivalentReadSet existing = receiver.get(key);
                // skip over sets that share the key but hold a different kind of reads.
                while (existing != null && !existing.belongs(incoming.example)) {
                    existing = receiver.get(++key);
                }
                if (existing == null) { // if not in the receiver we simply copy it over.
                    receiver.put(key, incoming);
                } else { // if present we increase the count.
                    existing.increase(incoming.size());
                }
            }
            return receiver;
        }

        /**
         * Adds a read to the buffer.
         * @param start the start position to store the read under.
         * @param end the end position to store the read under.
         * @param elem the read to add.
         */
        public void add(final int start, final int end, final GATKRead elem) {
            merge(start, end, Int2ObjectMaps.singleton(EquivalentReadSet.hashCode(elem), EquivalentReadSet.of(elem)),
                    ShardReadBuffer::mergeEquivalentReadSets);
        }

        /**
         * Removes the reads that end before a position.
         * <p>
         *     The traversal stops at the first node that starts at or after the position; this assumes
         *     that {@link #iterator()} visits nodes by increasing start position.
         * </p>
         * @param start the target position.
         */
        void removeUpstreamFrom(final int start) {
            final Iterator<Node<Int2ObjectMap<EquivalentReadSet>>> it = iterator();
            while (it.hasNext()) {
                final Node<Int2ObjectMap<EquivalentReadSet>> node = it.next();
                if (node.getStart() >= start) {
                    break;
                } else if (node.getEnd() < start) {
                    it.remove();
                }
            }
        }

        /**
         * Returns the read-sets in nodes that overlap an interval.
         * @param start the interval start.
         * @param end the interval end.
         * @return never {@code null}, perhaps an empty (and unmodifiable) list.
         */
        public List<EquivalentReadSet> overlapping(final int start, final int end) {
            final Iterator<Node<Int2ObjectMap<EquivalentReadSet>>> it = this.overlappers(start, end);
            if (!it.hasNext()) {
                return Collections.emptyList();
            }
            final List<EquivalentReadSet> result = new ArrayList<>();
            while (it.hasNext()) {
                result.addAll(it.next().getValue().values());
            }
            return result;
        }
    }

    /**
     * Minimal immutable pair of integers: a period and a repeat-length.
     */
    private static class PeriodAndRepeatLength {
        private final int period;
        private final int repeatLength;

        private PeriodAndRepeatLength(final int period, final int repeatLength) {
            this.period = period;
            this.repeatLength = repeatLength;
        }

        /**
         * Factory method.
         * @param period the period.
         * @param repeat the repeat-length.
         * @return never {@code null}.
         */
        private static PeriodAndRepeatLength of(final int period, final int repeat) {
            return new PeriodAndRepeatLength(period, repeat);
        }

        /**
         * @return the pair formatted as {@code (period,repeatLength)}.
         */
        @Override
        public String toString() {
            return new StringBuilder()
                    .append('(')
                    .append(period)
                    .append(',')
                    .append(repeatLength)
                    .append(')')
                    .toString();
        }
    }
}
